{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from pgn import PGN\n",
    "from loss import coverage_loss, mask_coverage_loss\n",
    "from batcher import batcher\n",
    "from utils.config import CKPT_DIR\n",
    "from utils.saveLoader import Vocab\n",
    "from utils.params import get_default_params\n",
    "from utils.config_gpu import config_gpu\n",
    "config_gpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n",
      "Building vocab ...\n",
      "Building the model ...\n",
      "Creating the checkpoint manager\n",
      "Initializing from scratch.\n",
      "Starting the training ...\n",
      "Epoch 1 Batch 1 Loss 3.4151 norm_loss 3.0928 cover_loss 0.3224\n",
      "Epoch 1 Batch 2 Loss 5.9242 norm_loss 5.3650 cover_loss 0.5592\n",
      "Epoch 1 Batch 3 Loss 6.8290 norm_loss 6.1843 cover_loss 0.6447\n",
      "Epoch 1 Batch 4 Loss 4.4605 norm_loss 4.0395 cover_loss 0.4211\n",
      "Epoch 1 Batch 5 Loss 6.4113 norm_loss 5.8060 cover_loss 0.6053\n",
      "Epoch 1 Batch 6 Loss 4.6694 norm_loss 4.2286 cover_loss 0.4408\n",
      "Epoch 1 Batch 7 Loss 5.5075 norm_loss 4.9878 cover_loss 0.5197\n",
      "Epoch 1 Batch 8 Loss 5.5762 norm_loss 5.0499 cover_loss 0.5263\n",
      "Epoch 1 Batch 9 Loss 6.6233 norm_loss 5.9983 cover_loss 0.6250\n",
      "Epoch 1 Batch 10 Loss 4.4605 norm_loss 4.0395 cover_loss 0.4211\n",
      "Epoch 1 Batch 11 Loss 3.3450 norm_loss 3.0292 cover_loss 0.3158\n",
      "my test break\n",
      "共花了 80.83 s\n"
     ]
    }
   ],
   "source": [
    "%run ../utils/train.py --batch_size 4 --test_run True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = get_default_params()\n",
    "model = PGN(params)\n",
    "vocab = Vocab(params[\"vocab_path\"], params[\"vocab_size\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def loss_function(real, pred, padding_mask):\n",
    "#     loss = 0\n",
    "#     for t in range(real.shape[1]):\n",
    "#         # if padding_mask:\n",
    "#         try:\n",
    "#             # real[:,t] shape (batch_size, ) (32, ) 表示这批样本的第t个词\n",
    "#             # pred[:, t, :] shape （batch_size, vocab_size) (32, 32252) 表示这批样本，第t个词的概率分布\n",
    "#             loss_ = loss_object(real[:, t], pred[:, t, :])\n",
    "#             mask = tf.cast(padding_mask[:, t], dtype=loss_.dtype)\n",
    "#             loss_ *= mask\n",
    "#             loss_ = tf.reduce_mean(loss_, axis=0)  # batch-wise\n",
    "#             loss += loss_\n",
    "#         # else:\n",
    "#         except:\n",
    "#             loss_ = loss_object(real[:, t], pred[:, t, :])\n",
    "#             loss_ = tf.reduce_mean(loss_, axis=0)  # batch-wise\n",
    "#             loss += loss_\n",
    "#     return tf.reduce_mean(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = batcher(vocab, params)\n",
    "encoder_batch_data, decoder_batch_data = next(iter(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_inp = encoder_batch_data[\"enc_input\"]\n",
    "extended_enc_input = encoder_batch_data[\"extended_enc_input\"]\n",
    "max_oov_len = encoder_batch_data[\"max_oov_len\"]\n",
    "dec_input = decoder_batch_data[\"dec_input\"]\n",
    "dec_target = decoder_batch_data[\"dec_target\"]\n",
    "cov_loss_wt=1\n",
    "enc_pad_mask=encoder_batch_data[\"sample_encoder_pad_mask\"]\n",
    "padding_mask=decoder_batch_data[\"sample_decoder_pad_mask\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_output, enc_hidden = model.call_encoder(enc_inp)\n",
    "dec_hidden = enc_hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions, _, attentions, coverages = model(dec_input,\n",
    "                                              dec_hidden,\n",
    "                                              enc_output,\n",
    "                                              dec_target,\n",
    "                                              extended_enc_input,\n",
    "                                              max_oov_len,\n",
    "                                              enc_pad_mask=enc_pad_mask,\n",
    "                                              use_coverage=True,\n",
    "                                              prev_coverage=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss definition. reduction='none' keeps the per-element losses so that the\n",
    "# caller can weight them before reducing.\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')\n",
    "def loss_function(real, pred, padding_mask):\n",
    "    \"\"\"Return the mean of the mask-weighted cross-entropy between real and pred.\n",
    "\n",
    "    Args:\n",
    "        real: target labels, passed to loss_object as y_true.\n",
    "        pred: predictions, passed to loss_object as y_pred.\n",
    "        padding_mask: weights multiplied element-wise into the unreduced loss\n",
    "            after being cast to the loss dtype.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: tf.reduce_mean over every element of the weighted loss.\n",
    "\n",
    "    NOTE(review): the earlier comment said both PAD and UNK positions are\n",
    "    excluded; only padding_mask is applied here, so that holds only if the\n",
    "    mask also zeroes UNK positions - confirm against the batcher.\n",
    "    \"\"\"\n",
    "    token_losses = loss_object(real, pred)\n",
    "    weights = tf.cast(padding_mask, dtype=token_losses.dtype)\n",
    "    return tf.reduce_mean(token_losses * weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=14123, shape=(), dtype=float32, numpy=5.425539>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loss import coverage_loss\n",
    "loss_function(dec_target, predictions, padding_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=14341, shape=(), dtype=float32, numpy=0.9807154>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cl = cov_loss_wt * coverage_loss(attentions, coverages, padding_mask)\n",
    "cl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "进入coverage_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=15089, shape=(), dtype=float32, numpy=0.9807154>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attn_dists = attentions\n",
    "pad_mask = False  # set to True to weight each step by padding_mask[:, step]\n",
    "\n",
    "# Running sum of the attention seen so far; starts as zeros shaped like attn_dists[0]\n",
    "# (the original note gives that shape as (batch_size, max_len_x)).\n",
    "coverage = tf.zeros_like(attn_dists[0])\n",
    "# One loss tensor is appended per element of attn_dists.\n",
    "cover_losses = []\n",
    "for step, attn_t in enumerate(attn_dists):\n",
    "    # Overlap between this step's attention and the accumulated coverage, summed over axis 1.\n",
    "    step_loss = tf.reduce_sum(tf.minimum(attn_t, coverage), [1])\n",
    "    if pad_mask:\n",
    "        step_loss = step_loss * tf.cast(padding_mask[:, step], dtype=step_loss.dtype)\n",
    "    cover_losses.append(step_loss)\n",
    "    # Fold this step's attention into the coverage vector for the next iteration.\n",
    "    coverage += attn_t\n",
    "loss = tf.reduce_mean(cover_losses)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "cover_losses = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "cover,attn = coverages[0], attn_dists[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "cover = tf.squeeze(cover)\n",
    "# attn =  tf.squeeze(attn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([64, 460]), TensorShape([64, 460]))"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cover.shape,attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=15129, shape=(64,), dtype=float32, numpy=\n",
       "array([1.0000001 , 0.99999976, 1.        , 0.9999999 , 1.0000002 ,\n",
       "       1.0000001 , 1.0000001 , 0.99999994, 0.99999994, 1.        ,\n",
       "       1.        , 0.99999976, 1.0000001 , 1.0000001 , 0.99999994,\n",
       "       1.        , 0.9999998 , 1.0000001 , 0.9999999 , 0.9999999 ,\n",
       "       0.99999976, 1.0000002 , 1.        , 0.9999999 , 0.9999999 ,\n",
       "       1.0000001 , 1.        , 0.9999999 , 1.        , 1.        ,\n",
       "       1.        , 1.        , 1.0000001 , 0.9999998 , 0.9999999 ,\n",
       "       0.99999994, 0.9999999 , 0.99999976, 1.        , 0.9999998 ,\n",
       "       1.0000001 , 0.9999999 , 0.9999997 , 1.        , 0.99999976,\n",
       "       1.        , 0.99999994, 0.9999999 , 1.        , 0.99999994,\n",
       "       1.        , 0.9999999 , 0.99999976, 1.0000001 , 1.        ,\n",
       "       1.        , 1.        , 0.99999964, 1.        , 1.        ,\n",
       "       1.        , 1.        , 0.9999998 , 0.99999994], dtype=float32)>"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_sum(tf.minimum(cover, attn), axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=17488, shape=(), dtype=float32, numpy=0.52283657>"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loss import coverage_loss\n",
    "coverage_loss(attn_dists, coverages, padding_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-step coverage loss: sum over the last axis of min(coverage, attention).\n",
    "# tf.squeeze drops the size-1 dimensions of each coverage tensor before the comparison.\n",
    "# The per-step results are stacked along axis 1.\n",
    "cover_losses = tf.stack(\n",
    "    [\n",
    "        tf.reduce_sum(tf.minimum(tf.squeeze(cover), attn), axis=-1)\n",
    "        for cover, attn in zip(coverages, attn_dists)\n",
    "    ],\n",
    "    1,\n",
    ")\n",
    "# Weight by the decoder padding mask, then average over every element.\n",
    "mask = tf.cast(padding_mask, dtype=cover_losses.dtype)\n",
    "cover_losses *= mask\n",
    "loss = tf.reduce_mean(cover_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=16405, shape=(), dtype=float32, numpy=0.52283657>"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_mean(cover_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=16399, shape=(), dtype=float32, numpy=27.1875>"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss = tf.reduce_sum(tf.reduce_mean(cover_losses, axis=0))\n",
    "loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "配置保存点管理器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored from E:\\GitHub\\QA-abstract-and-reasoning\\data\\checkpoints\\training_checkpoints\\test_model-1\n"
     ]
    }
   ],
   "source": [
    "# checkpoint = tf.train.Checkpoint(Seq2Seq=model)\n",
    "# checkpoint_manager = tf.train.CheckpointManager(checkpoint, CKPT_DIR, max_to_keep=5)\n",
    "# checkpoint.restore(checkpoint_manager.latest_checkpoint)\n",
    "# if checkpoint_manager.latest_checkpoint:\n",
    "#     print(\"Restored from {}\".format(checkpoint_manager.latest_checkpoint))\n",
    "# else:\n",
    "#     print(\"Initializing from scratch.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from utils.train_helper import train_model\n",
    "# train_model(model, vocab, params, checkpoint_manager)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20736"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params['max_train_steps']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'E:\\\\GitHub\\\\QA-abstract-and-reasoning\\\\data\\\\checkpoints\\\\training_checkpoints\\\\test_model-1'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "checkpoint_prefix = os.path.join(CKPT_DIR, \"test_model\")\n",
    "checkpoint.save(file_prefix=checkpoint_prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一探train_model究竟\n",
    "\n",
    "- 摒弃`@tf.function`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings and objects, all read from params.\n",
    "epochs = params['epochs']\n",
    "\n",
    "# Adagrad configured with the initial accumulator value and gradient-norm clipping from params.\n",
    "optimizer = tf.keras.optimizers.Adagrad(\n",
    "    params['learning_rate'],\n",
    "    initial_accumulator_value=params['adagrad_init_acc'],\n",
    "    clipnorm=params['max_grad_norm'],\n",
    ")\n",
    "\n",
    "# reduction='none' keeps per-element losses so they can be masked before reducing.\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')\n",
    "# Tip: optimizer.get_config() and loss_object.get_config() show the effective settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_function(real, pred, padding_mask):\n",
    "    \"\"\"Sum over decoder steps of the batch-averaged cross-entropy.\n",
    "\n",
    "    Args:\n",
    "        real: target labels, read one step at a time as real[:, t].\n",
    "        pred: predictions, read one step at a time as pred[:, t, :].\n",
    "        padding_mask: optional mask read as padding_mask[:, t] and multiplied\n",
    "            into the per-sample loss of step t. Pass None to skip masking.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor holding the accumulated loss.\n",
    "\n",
    "    Raises:\n",
    "        Any error raised while applying the mask now propagates. The previous\n",
    "        bare except hid such errors and silently fell back to the unmasked loss.\n",
    "    \"\"\"\n",
    "    loss = 0\n",
    "    for t in range(real.shape[1]):\n",
    "        loss_ = loss_object(real[:, t], pred[:, t, :])\n",
    "        # Explicit None check: the truth value of a tensor mask is ambiguous,\n",
    "        # and exception-driven branching masked genuine failures.\n",
    "        if padding_mask is not None:\n",
    "            loss_ *= tf.cast(padding_mask[:, t], dtype=loss_.dtype)\n",
    "        loss += tf.reduce_mean(loss_, axis=0)  # batch-wise\n",
    "    # loss is already a single accumulated value; reduce_mean leaves it unchanged\n",
    "    # and guarantees a tensor is returned even when there are zero steps.\n",
    "    return tf.reduce_mean(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @tf.function is deliberately left off so the step runs eagerly and can be inspected.\n",
    "def train_step(enc_inp, extended_enc_input, max_oov_len,\n",
    "               dec_input, dec_target, cov_loss_wt,\n",
    "               enc_pad_mask, padding_mask=None):\n",
    "    \"\"\"Run one optimisation step on a single batch and return its loss.\n",
    "\n",
    "    The forward pass calls model.call_encoder and then model with coverage\n",
    "    enabled. The loss is loss_function(...) plus cov_loss_wt times\n",
    "    coverage_loss(...). Gradients are taken with respect to\n",
    "    model.trainable_variables and applied through the notebook-level optimizer.\n",
    "\n",
    "    Args:\n",
    "        enc_inp: encoder input passed to model.call_encoder.\n",
    "        extended_enc_input: forwarded to model unchanged.\n",
    "        max_oov_len: forwarded to model unchanged.\n",
    "        dec_input: decoder input passed to model.\n",
    "        dec_target: decoder target, passed to model and to loss_function.\n",
    "        cov_loss_wt: multiplier applied to the coverage loss term.\n",
    "        enc_pad_mask: forwarded to model as enc_pad_mask.\n",
    "        padding_mask: forwarded to loss_function and coverage_loss.\n",
    "\n",
    "    Returns:\n",
    "        The combined batch loss.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as tape:\n",
    "        enc_output, enc_hidden = model.call_encoder(enc_inp)\n",
    "        # The encoder's hidden state is used as the decoder's initial hidden state.\n",
    "        predictions, _, attentions, coverages = model(\n",
    "            dec_input,\n",
    "            enc_hidden,\n",
    "            enc_output,\n",
    "            dec_target,\n",
    "            extended_enc_input,\n",
    "            max_oov_len,\n",
    "            enc_pad_mask=enc_pad_mask,\n",
    "            use_coverage=True,\n",
    "            prev_coverage=None,\n",
    "        )\n",
    "\n",
    "        prediction_term = loss_function(dec_target, predictions, padding_mask)\n",
    "        coverage_term = coverage_loss(attentions, coverages, padding_mask)\n",
    "        batch_loss = prediction_term + cov_loss_wt * coverage_term\n",
    "\n",
    "    trainable_vars = model.trainable_variables\n",
    "    grads = tape.gradient(batch_loss, trainable_vars)\n",
    "    optimizer.apply_gradients(zip(grads, trainable_vars))\n",
    "\n",
    "    return batch_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "dataset = batcher(vocab, params)\n",
    "total_loss = 0\n",
    "step = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 进入for 循环\n",
    "# for encoder_batch_data, decoder_batch_data in dataset:\n",
    "encoder_batch_data, decoder_batch_data = next(iter(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_loss = train_step(enc_inp = encoder_batch_data[\"enc_input\"],\n",
    "                        extended_enc_input = encoder_batch_data[\"extended_enc_input\"],\n",
    "                        max_oov_len = encoder_batch_data[\"max_oov_len\"],\n",
    "                        dec_input = decoder_batch_data[\"dec_input\"],\n",
    "                        dec_target = decoder_batch_data[\"dec_target\"],\n",
    "                        cov_loss_wt=0.5,\n",
    "                        enc_pad_mask=encoder_batch_data[\"sample_encoder_pad_mask\"],\n",
    "                        padding_mask=decoder_batch_data[\"sample_decoder_pad_mask\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "进入train_step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
